#!/bin/bash
#  pfastq-dump: parallelize fastq-dump
#  Original python implementation: https://github.com/rvalieris/parallel-fastq-dump
set -e

# Release number reported by --version
VERSION='0.1.6'

# Defaults used when the matching option is not given: output and scratch
# data both go to the physical (symlink-resolved) working directory, and a
# single fastq-dump job is run.
OUTDIR="$(pwd -P)"
TMPDIR="${OUTDIR}"
NTHREADS=1

# functions for version and help
#######################################
# Print the pfastq-dump version together with the fastq-dump version.
# Globals:   VERSION (read)
# Outputs:   one line on stdout
#######################################
print_version(){
  local dump_version

  # The version is the third field of the first line whose first field names
  # the tool. The optional leading double quote lets the match succeed both
  # for a `fastq-dump : X` banner and for one that quotes the program name
  # (`"fastq-dump" version X`). Failure to query the tool is not fatal: the
  # version is simply left blank.
  dump_version=$(fastq-dump --version \
    | awk '$1 ~ /^"?fastq/ { print $3; exit }') || dump_version=""

  echo "pfastq-dump version ${VERSION} using fastq-dump ${dump_version}"
}

#######################################
# Print the usage text.
# Outputs:   usage text on stdout
#######################################
print_help(){
  # Quoted delimiter: the text is emitted literally, with no expansion.
  cat <<'EOS'
Usage:
pfastq-dump --threads <number of threads> [options] <path> [<path>...]

pfastq-dump option:
-t|--threads         number of threads
-s|--sra-id          SRA id
-O|--outdir          output directory
--tmpdir             temporary directory
-Z|--stdout          write merged output to stdout

Extra arguments will be passed through

Other:
-v|--version         show version info
-h|--help            program usage
EOS
}

# parse arguments
#
# Fill the option globals from the command line:
#   NTHREADS, SRAID, OUTDIR, TMPDIR, STDOUT - pfastq-dump's own options
#   OPTIONS     - space-separated options forwarded to fastq-dump
#   INPUT_FILES - space-separated input paths (arguments ending in "sra")
#
# A word that does not start with "-" and directly follows a forwarded
# option is taken to be that option's value and is forwarded with it. Any
# other unrecognized word is reported on stderr and ignored.
#
# Exits 1 when an option is missing its value or --threads is not a positive
# integer; exits 0 after printing --version / --help.
parse_args(){
  local key
  local forwarded
  local prev_forwarded="false"

  while [[ $# -gt 0 ]]; do
    key="${1}"
    forwarded="false"

    case "${key}" in
      -v|--version)
        print_version
        exit 0
        ;;
      -h|--help)
        print_help
        exit 0
        ;;
      -t|--threads|-s|--sra-id|-O|--outdir|--tmpdir)
        # Without this check the extra shift below would fail and, under
        # `set -e`, end the script without any message.
        if [[ $# -lt 2 ]]; then
          echo "ERROR: option ${key} requires a value." >&2
          exit 1
        fi
        case "${key}" in
          -t|--threads)
            if [[ ! "${2}" =~ ^[0-9]+$ ]] || (( 10#${2} < 1 )); then
              echo "ERROR: ${key} expects a positive integer, got '${2}'." >&2
              exit 1
            fi
            # 10# forces base 10 so a value such as "08" is not read as octal
            NTHREADS=$(( 10#${2} ))
            ;;
          -s|--sra-id)
            SRAID="${2}"
            ;;
          -O|--outdir)
            OUTDIR="${2}"
            ;;
          --tmpdir)
            TMPDIR="${2}"
            ;;
        esac
        shift
        ;;
      -Z|--stdout)
        STDOUT="true"
        ;;
      -*)
        OPTIONS="${OPTIONS} ${1}"
        forwarded="true"
        ;;
      *sra)
        INPUT_FILES="${INPUT_FILES} ${1}"
        ;;
      *)
        if [[ "${prev_forwarded}" == "true" ]]; then
          # value belonging to the forwarded option just before it
          OPTIONS="${OPTIONS} ${1}"
        else
          echo "WARNING: ignoring unrecognized argument '${1}'" >&2
        fi
        ;;
    esac

    prev_forwarded="${forwarded}"
    shift
  done
}

parse_args "$@"

# function to exec multiple fastq dump
#
# Split the spots of one SRA input into contiguous ranges, run one
# fastq-dump per range in the background, then concatenate the per-range
# output files in range order.
#
# Globals:   NTHREADS, TMPDIR, OUTDIR, OPTIONS, STDOUT (all read)
# Arguments: $1 - SRA file path or accession
#            $2 - total number of spots in $1
# Outputs:   merged files in OUTDIR, or the merged data on stdout when
#            STDOUT is "true"; the block list and everything fastq-dump
#            prints go to stderr
# Exits:     1 on an invalid spot or thread count, or when any fastq-dump
#            job fails
parallel_fastq_dump(){
  local sra="${1}"
  local count="${2}"
  local nchunks="${NTHREADS:-1}"

  if [[ ! "${count}" =~ ^[0-9]+$ ]] || (( 10#${count} < 1 )); then
    echo "ERROR: invalid spot count '${count}' for ${sra}" >&2
    exit 1
  fi
  if [[ ! "${nchunks}" =~ ^[0-9]+$ ]] || (( 10#${nchunks} < 1 )); then
    echo "ERROR: invalid number of threads '${nchunks}'" >&2
    exit 1
  fi
  # 10# forces base 10 so values with leading zeros are not read as octal
  count=$(( 10#${count} ))
  nchunks=$(( 10#${nchunks} ))

  # Never start more jobs than there are spots: otherwise the per-job size
  # would be zero and the ranges would be empty.
  if (( nchunks > count )); then
    nchunks="${count}"
  fi

  local sraid
  sraid="$(basename -- "${sra}")"
  sraid="${sraid%.sra}"
  local td="${TMPDIR}/pfd.tmp/${sraid}"

  local avg=$(( count / nchunks ))
  local remain=$(( count % nchunks ))

  local -a mins=()
  local -a maxs=()
  local -a out=()
  local i
  local first=1
  local last

  # Calculate blocks per thread; the last block also takes the remainder
  for (( i = 1; i <= nchunks; i++ )); do
    last=$(( first + avg - 1 ))
    if (( i == nchunks )); then
      last=$(( last + remain ))
    fi
    mins+=("${first}")
    maxs+=("${last}")
    out+=("${first},${last}")
    first=$(( first + avg ))
  done
  echo "blocks: ${out[*]}" >&2

  # Start from an empty scratch directory so chunks left behind by an
  # earlier, interrupted run cannot end up in the merged output.
  rm -rf -- "${td}"

  # Run fastq-dump and send it to background; chunk i writes to ${td}/<i+1>
  local -a pids=()
  local d
  for (( i = 0; i < nchunks; i++ )); do
    d="${td}/$(( i + 1 ))"
    mkdir -p -- "${d}"

    # OPTIONS is a space-separated option list and is split on purpose.
    # shellcheck disable=SC2086
    fastq-dump -N "${mins[i]}" -X "${maxs[i]}" -O "${d}" ${OPTIONS} "${sra}" 1>&2 &
    pids+=("$!")
  done

  # Wait subprocesses and collect exit status
  local failure=0
  local pid
  for pid in "${pids[@]}"; do
    wait "${pid}" || failure=$(( failure + 1 ))
  done

  if (( failure != 0 )); then
    echo "ERROR during decompressing" >&2
    exit 1
  fi

  # Merge split fastq files. The first chunk defines the set of file names.
  local -a files=()
  local path
  for path in "${td}/1"/*; do
    [[ -e "${path}" ]] || continue
    files+=("${path##*/}")
  done

  local fo
  if [[ "${STDOUT}" != "true" ]]; then
    mkdir -p -- "${OUTDIR}"
    # Start every target empty so a re-run replaces earlier output instead
    # of appending to it.
    for fo in "${files[@]}"; do
      : > "${OUTDIR}/${fo}"
    done
  fi

  for (( i = 1; i <= nchunks; i++ )); do
    for fo in "${files[@]}"; do
      if [[ "${STDOUT}" == "true" ]]; then
        cat "${td}/${i}/${fo}"
      else
        cat "${td}/${i}/${fo}" >> "${OUTDIR}/${fo}"
      fi
    done
  done
}

# function to calculate spot count using sra-stat
#
# Runs `sra-stat --meta --quick` on the input and adds up, over all output
# lines, the number found before the first ':' of the third '|'-separated
# field. Lines whose value is not a plain non-negative integer are skipped.
#
# Arguments: $1 - SRA file path or accession
# Outputs:   the total on stdout
# Returns:   non-zero (with a message on stderr) when sra-stat fails
calc_spot_count(){
  local sra="${1}"
  local txt

  # Declared separately from the assignment so a failing sra-stat is seen
  # here instead of being hidden by `local`.
  if ! txt="$(sra-stat --meta --quick "${sra}")"; then
    echo "ERROR: sra-stat failed for ${sra}" >&2
    return 1
  fi

  local total=0
  local accession group field rest n
  # Read whole lines and split them on '|' only, so spaces inside a field
  # do not break a record apart.
  while IFS='|' read -r accession group field rest; do
    n="${field%%:*}"
    [[ "${n}" =~ ^[0-9]+$ ]] || continue
    # 10# forces base 10 so leading zeros are not read as octal
    total=$(( total + 10#${n} ))
  done <<< "${txt}"

  echo "${total}"
}

# function to check prerequisites
#
# Verify that a required command can be run and report its version.
#
# Arguments: $1 - command name
# Outputs:   "Using <version text>" on stderr, where the version text is the
#            output of `<command> --version` with newlines removed
# Exits:     1 (with a message on stderr) when the command is not found
check_binary_location(){
  local cmd="${1}"

  # `command -v` is a shell builtin, so the check does not itself depend on
  # an external `which` being installed.
  if ! command -v "${cmd}" >/dev/null 2>&1; then
    echo "ERROR: ${cmd} not found." >&2
    exit 1
  fi

  echo "Using $("${cmd}" --version | tr -d '\n')" >&2
}

# main function
#
# Check the required tools, then process either the id given with --sra-id
# or, when none was given, every collected input file. With neither, print
# the usage text and exit 0.
check_binary_location "sra-stat"
check_binary_location "fastq-dump"

echo "tmpdir: ${TMPDIR}" >&2
echo "outdir: ${OUTDIR}" >&2

# Remove the scratch tree. `:?` aborts instead of expanding to "/pfd.tmp"
# should TMPDIR ever be empty.
cleanup_tmpfiles(){
  rm -fr -- "${TMPDIR:?}/pfd.tmp"
}

if [[ -n "${SRAID}" ]]; then
  # Installed only once work starts, so the help-only path below never
  # touches the scratch tree; from here on it is removed on failure too.
  trap cleanup_tmpfiles EXIT
  spot_count="$(calc_spot_count "${SRAID}")"
  parallel_fastq_dump "${SRAID}" "${spot_count}"

elif [[ -n "${INPUT_FILES}" ]]; then
  trap cleanup_tmpfiles EXIT
  # INPUT_FILES is a space-separated list and is split on purpose.
  for srafile in ${INPUT_FILES}; do
    spot_count="$(calc_spot_count "${srafile}")"
    parallel_fastq_dump "${srafile}" "${spot_count}"
  done

else
  print_help
  exit 0
fi

# cleaning tmpfiles
cleanup_tmpfiles
trap - EXIT
